package com.xiam.consia.ml_new.classifiers;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.xiam.consia.algs.predict.property.PropertyManager;
import com.xiam.consia.data.constants.PropertyConstants;
import com.xiam.consia.featurecapture.store.FeatureSample;
import com.xiam.consia.featurecapture.store.FeatureSampleStore;
import com.xiam.consia.featurecapture.store.attributes.AttributeStore;
import com.xiam.consia.ml.classifiers.ClassifierConstants;
import com.xiam.consia.ml_new.attributeselection.AttributeSelection;
import com.xiam.consia.ml_new.data.ProbResults;
import com.xiam.consia.ml_new.data.builder.ModelThreadPool;
import com.xiam.consia.ml_new.tree.builder.TreeBuilder;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;

/* loaded from: classes.dex */
public class RandomForest extends Classifier<FeatureSample, FeatureSampleStore> {
    public static final String PREDICT_ML_DATA_MODEL_TREE_DESERIALIZATION_THREADS = "PREDICT_ML_DATA_MODEL_TREE_DESERIALIZATION_THREADS";
    private final AttributeStore.AttributeNameSerialiser attributeNameSerialiser;
    private long bagPercentSize;
    private String class1Label;
    private BigDecimal class1Weight;
    private String class2Label;
    private BigDecimal class2Weight;
    private int leafCount;
    private long maxNumRecordsPerClass;
    private final long maxTreeDepth;
    private long minNumRecordsPerClass;
    private final long modelLoadTimeoutMins;
    private int nodeCount;
    private int numDeserialisationThreads;
    private long numTrees;
    private final ClassifierConstants.PredictionType predictionType;
    private Random randomGenerator;
    private long randomSeed;
    private boolean traceTreeIteration;
    // Only ever mutated from the thread calling buildClassifier/deserialise; not thread-safe.
    private final List<RandomForestTree> trees;
    private boolean useRandomSplitting;
    private static final Ordering<ProbResults> PROB_RESULTS_ORDERING = Ordering.from(new ProbabilityResultsComparator());
    private static final Integer ONE = 1;
    private static final BigDecimal MINUS_ONE = BigDecimal.valueOf(-1.0d);

    /**
     * Creates an empty forest configured from {@code propertyManager} for the given prediction type.
     *
     * <p>Note: this constructor calls the overridable {@link #loadProperties}; subclasses overriding
     * it will see a partially-initialised instance.
     *
     * @param propertyManager source of all tuning properties
     * @param predictionType what this forest predicts; also fixes the class count passed to the superclass
     * @param str first argument forwarded to the {@code Classifier} superclass constructor
     * @param attributeNameSerialiser used when (de)serialising individual trees
     */
    public RandomForest(PropertyManager propertyManager, ClassifierConstants.PredictionType predictionType, String str, AttributeStore.AttributeNameSerialiser attributeNameSerialiser) {
        super(str, predictionType.getClassCount());
        this.traceTreeIteration = false;
        this.trees = Lists.newArrayList();
        this.numTrees = 5L;
        this.class1Weight = BigDecimal.ZERO;
        this.class2Weight = BigDecimal.ZERO;
        this.useRandomSplitting = true;
        this.predictionType = predictionType;
        this.attributeNameSerialiser = attributeNameSerialiser;
        this.maxTreeDepth = TreeBuilder.getMaxTreeDepth(propertyManager, predictionType);
        this.numDeserialisationThreads = ModelThreadPool.getThreadsByCoreProperty(PREDICT_ML_DATA_MODEL_TREE_DESERIALIZATION_THREADS, propertyManager);
        this.modelLoadTimeoutMins = propertyManager.getLongProperty(PropertyConstants.MODEL_LOAD_TIMEOUT_MINS);
        loadProperties(predictionType, propertyManager);
    }

    /**
     * Appends a tree to the forest and adds its node/leaf counts to the running totals.
     *
     * <p>Not thread-safe: {@code trees} is a plain list and the counters are updated non-atomically,
     * so this must only be called from a single thread.
     */
    public void addTreeToForest(RandomForestTree randomForestTree) {
        this.trees.add(randomForestTree);
        this.nodeCount += randomForestTree.getTree().nodeCount();
        this.leafCount += randomForestTree.getTree().leafCount();
    }

    /** Removes every tree and resets the aggregate node/leaf counters so they match the empty forest. */
    private void clearForest() {
        this.trees.clear();
        this.nodeCount = 0;
        this.leafCount = 0;
    }

    /** Parses a decimal string (via {@link Double#parseDouble}) into a {@link BigDecimal}. */
    private static BigDecimal asBigDecimal(String str) {
        return BigDecimal.valueOf(Double.parseDouble(str));
    }

    /** Arithmetic mean of {@code list}; yields NaN for an empty list (0.0 / 0). */
    private static double calcMean(List<Double> list) {
        double sum = 0.0d;
        for (Double value : list) {
            sum += value.doubleValue();
        }
        return sum / list.size();
    }

    /**
     * Accumulates soft votes for the two binary classes across all trees.
     *
     * <p>For each tree, its top result's probability {@code p} is added to the predicted class and
     * {@code 1 - p} to the other class. Assumes class1Label and class2Label are distinct — TODO confirm.
     *
     * @return map from class label to summed vote (both labels always present)
     */
    private Map<String, BigDecimal> calculateClassVotes(FeatureSample featureSample) {
        Map<String, BigDecimal> votes = Maps.newHashMap();
        votes.put(this.class1Label, BigDecimal.ZERO);
        votes.put(this.class2Label, BigDecimal.ZERO);
        double class1WeightValue = this.class1Weight.doubleValue();
        double class2WeightValue = this.class2Weight.doubleValue();
        for (RandomForestTree tree : this.trees) {
            ProbResults top = tree.classify(featureSample, this.numberOfClasses, class1WeightValue, class2WeightValue).get(0);
            BigDecimal probability = BigDecimal.valueOf(top.getPredictionProbability());
            BigDecimal complement = BigDecimal.ONE.subtract(probability);
            boolean predictedClass1 = top.getMostLikelyClass().equalsIgnoreCase(this.class1Label);
            String winner = predictedClass1 ? this.class1Label : this.class2Label;
            String loser = predictedClass1 ? this.class2Label : this.class1Label;
            votes.put(winner, votes.get(winner).add(probability));
            votes.put(loser, votes.get(loser).add(complement));
        }
        return votes;
    }

    /**
     * Draws the training bag for one tree.
     *
     * <p>Binary prediction types use {@code Bagger.createDynamicBag} bounded by the per-class
     * min/max record counts; other types take {@code bagPercentSize}% of the store's records.
     */
    private FeatureSampleStore getCurrentBag(FeatureSampleStore featureSampleStore) {
        if (this.predictionType.isBinaryClassification()) {
            return Bagger.createDynamicBag(featureSampleStore, this.bagPercentSize, this.minNumRecordsPerClass, this.maxNumRecordsPerClass, this.randomGenerator);
        }
        int bagSize = (int) ((this.bagPercentSize / 100.0d) * featureSampleStore.getNumRecords());
        return Bagger.createBag(featureSampleStore.getRandomRecordSupplier(this.randomGenerator), bagSize);
    }

    /**
     * Hard majority vote: each tree votes for its most likely class.
     *
     * @return the class with the most votes (first encountered wins ties) and the fraction of trees
     *     that voted for it; the fraction is NaN for an empty forest
     */
    private ProbResults majorityVote(FeatureSample featureSample) {
        Map<String, Integer> voteCounts = Maps.newHashMap();
        double class1WeightValue = this.class1Weight.doubleValue();
        double class2WeightValue = this.class2Weight.doubleValue();
        for (RandomForestTree tree : this.trees) {
            String label = tree.classify(featureSample, this.numberOfClasses, class1WeightValue, class2WeightValue).get(0).getMostLikelyClass();
            Integer count = voteCounts.get(label);
            voteCounts.put(label, count == null ? ONE : Integer.valueOf(count.intValue() + 1));
        }
        String bestLabel = "";
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : voteCounts.entrySet()) {
            if (entry.getValue().intValue() > bestCount) {
                bestCount = entry.getValue().intValue();
                bestLabel = entry.getKey();
            }
        }
        // Floating-point division: integer division would truncate every non-unanimous vote to 0.
        return new ProbResults(bestLabel, (double) bestCount / this.trees.size());
    }

    /**
     * Soft vote over all classes: sums each tree's per-class probabilities and averages them over
     * the number of trees actually in the forest.
     *
     * @return one result per class seen, ordered by {@code ProbabilityResultsComparator}
     */
    private List<ProbResults> majorityVoteAllClassesRanked(FeatureSample featureSample) {
        Map<String, Double> probabilitySums = Maps.newHashMap();
        double class1WeightValue = this.class1Weight.doubleValue();
        double class2WeightValue = this.class2Weight.doubleValue();
        for (RandomForestTree tree : this.trees) {
            for (ProbResults result : tree.classify(featureSample, this.numberOfClasses, class1WeightValue, class2WeightValue)) {
                String label = result.getMostLikelyClass();
                double probability = result.getPredictionProbability();
                Double previous = probabilitySums.get(label);
                probabilitySums.put(label, Double.valueOf(previous == null ? probability : probability + previous.doubleValue()));
            }
        }
        // Average over the trees that actually voted (not the configured numTrees, which can differ
        // from the loaded model's size after deserialise).
        int voters = this.trees.size();
        List<ProbResults> averaged = Lists.newArrayList();
        for (Map.Entry<String, Double> entry : probabilitySums.entrySet()) {
            averaged.add(new ProbResults(entry.getKey(), entry.getValue().doubleValue() / voters));
        }
        return PROB_RESULTS_ORDERING.immutableSortedCopy(averaged);
    }

    /**
     * Class-weighted soft vote for binary prediction types.
     *
     * <p>Each class's summed vote is multiplied by its configured weight; the larger weighted score
     * wins, with ties going to class 1.
     *
     * @return the winning label and its weighted score divided by the number of trees
     */
    private ProbResults weightedVote(FeatureSample featureSample) {
        Map<String, BigDecimal> votes = calculateClassVotes(featureSample);
        BigDecimal bestScore = MINUS_ONE;
        String bestLabel = "";
        for (Map.Entry<String, BigDecimal> entry : votes.entrySet()) {
            String label = entry.getKey();
            if (label.equalsIgnoreCase(this.class1Label)) {
                BigDecimal score = entry.getValue().multiply(this.class1Weight);
                // ">=" so class 1 wins ties regardless of map iteration order.
                if (score.compareTo(bestScore) >= 0) {
                    bestScore = score;
                    bestLabel = label;
                }
            } else if (label.equalsIgnoreCase(this.class2Label)) {
                BigDecimal score = entry.getValue().multiply(this.class2Weight);
                if (score.compareTo(bestScore) > 0) {
                    bestScore = score;
                    bestLabel = label;
                }
            }
        }
        return new ProbResults(bestLabel, bestScore.doubleValue() / this.trees.size());
    }

    /** Discards any existing trees and grows {@code numTrees} new ones, each on a freshly drawn bag. */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public void buildClassifier(AttributeStore attributeStore, FeatureSampleStore featureSampleStore, AttributeSelection attributeSelection) {
        clearForest();
        for (int i = 0; i < this.numTrees; i++) {
            RandomForestTree tree = RandomForestTree.create(this.predictionType, i, this.traceTreeIteration);
            tree.buildClassifier(attributeStore, getCurrentBag(featureSampleStore), this.predictionType, attributeSelection, this.maxTreeDepth, this.useRandomSplitting);
            addTreeToForest(tree);
        }
    }

    /** Binary prediction types use {@link #weightedVote}; all others use {@link #majorityVote}. */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public ProbResults classify(FeatureSample featureSample) {
        return this.predictionType.isBinaryClassification() ? weightedVote(featureSample) : majorityVote(featureSample);
    }

    /** Returns every class with its forest-averaged probability, ranked. */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public List<ProbResults> classifyRanked(FeatureSample featureSample) {
        return majorityVoteAllClassesRanked(featureSample);
    }

    /**
     * Replaces the forest with one tree per entry of {@code zipFile}, deserialising entries in
     * parallel on a dedicated pool.
     *
     * <p>Trees are added on the calling thread in zip-entry order once their tasks finish. If loading
     * times out or is interrupted, the pool is stopped and only trees that had finished are kept.
     * Entries that fail to deserialise are logged and skipped. If interrupted, the thread's interrupt
     * status is restored before returning.
     */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public void deserialise(final ZipFile zipFile, final ClassifierConstants.PredictionType predictionType) throws IOException {
        long startMillis = System.currentTimeMillis();
        clearForest();
        ListeningExecutorService pool = ModelThreadPool.createThreadPool(this.numDeserialisationThreads);
        List<Future<RandomForestTree>> pending = new ArrayList<Future<RandomForestTree>>();
        for (ZipEntry entry : Collections.list(zipFile.entries())) {
            final ZipEntry zipEntry = entry;
            pending.add(pool.submit(new Callable<RandomForestTree>() {
                @Override // java.util.concurrent.Callable
                public RandomForestTree call() throws Exception {
                    RandomForestTree tree = RandomForestTree.create(predictionType, -1, traceTreeIteration);
                    DataInputStream in = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(zipEntry)));
                    try {
                        tree.deserialise(attributeNameSerialiser, in);
                    } finally {
                        in.close();
                    }
                    // Returned rather than added here: the forest is not safe for concurrent mutation.
                    return tree;
                }
            }));
        }
        pool.shutdown();
        boolean interrupted = false;
        try {
            if (!pool.awaitTermination(this.modelLoadTimeoutMins, TimeUnit.MINUTES)) {
                logger.w("RandomForest.deserialise(): deserialising model timed out before completion", new Object[0]);
            } else {
                logger.d("RandomForest.deserialise(): Time to deserialise models: %d ms", Long.valueOf(System.currentTimeMillis() - startMillis));
            }
        } catch (InterruptedException e) {
            logger.w("RandomForest.deserialise(): Interrupted while deserialising model.", e);
            interrupted = true;
        }
        // No-op after a clean termination; otherwise stops stragglers so none outlive this call.
        pool.shutdownNow();
        for (Future<RandomForestTree> future : pending) {
            if (!future.isDone() || future.isCancelled()) {
                continue;
            }
            try {
                addTreeToForest(future.get());
            } catch (ExecutionException e) {
                logger.w("RandomForest.deserialise(): failed to deserialise a tree.", e.getCause());
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        // Restore the interrupt only now: some Future.get implementations throw immediately when the
        // flag is set, which would otherwise discard already-finished trees.
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /** Number of trees currently in the forest (not the configured {@code numTrees}). */
    public long getNumTrees() {
        return this.trees.size();
    }

    /** Returns the live, mutable list of trees backing this forest. */
    public List<RandomForestTree> getTrees() {
        return this.trees;
    }

    /**
     * Reads labels, class weights, tree count, per-class record bounds, random seed, bag size and
     * splitting/tracing flags for {@code predictionType}.
     *
     * <p>Types without a dedicated branch for a setting keep the current value (e.g. class weights
     * stay at their previous value for PLACE, MOST_LIKELY_APPS, CONTACTS and BATTERYDRAIN).
     */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public void loadProperties(ClassifierConstants.PredictionType predictionType, PropertyManager propertyManager) {
        this.class1Label = propertyManager.getStringProperty(PropertyConstants.PREDICT_ML_CLASS1_LABEL);
        this.class2Label = propertyManager.getStringProperty(PropertyConstants.PREDICT_ML_CLASS2_LABEL);
        if (predictionType == ClassifierConstants.PredictionType.APP) {
            this.class1Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_APP));
            this.class2Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_APP));
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_APP);
            this.minNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MIN_NUM_RECORDS_PER_CLASS_APP);
            this.maxNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MAX_NUM_RECORDS_PER_CLASS_APP);
        } else if (predictionType == ClassifierConstants.PredictionType.PHONEON) {
            this.class1Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_PHONEON));
            this.class2Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_PHONEON));
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_PHONEON);
            this.minNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MIN_NUM_RECORDS_PER_CLASS_PHONEON);
            this.maxNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MAX_NUM_RECORDS_PER_CLASS_PHONEON);
        } else if (predictionType == ClassifierConstants.PredictionType.PLACE) {
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_PLACE);
        } else if (predictionType == ClassifierConstants.PredictionType.MOST_LIKELY_APPS) {
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_MOST_LIKELY_APPS);
        } else if (predictionType == ClassifierConstants.PredictionType.CONTACTS) {
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_CONTACTS);
        } else if (predictionType == ClassifierConstants.PredictionType.PLACEMOVE) {
            this.class1Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_PLACEMOVE));
            this.class2Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_PLACEMOVE));
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_PLACEMOVE);
            this.minNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MIN_NUM_RECORDS_PER_CLASS_PLACEMOVE);
            this.maxNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MAX_NUM_RECORDS_PER_CLASS_PLACEMOVE);
        } else if (predictionType == ClassifierConstants.PredictionType.BATTERYCHARGE) {
            this.class1Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_BATTERYCHARGE));
            this.class2Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_BATTERYCHARGE));
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_BATTERYCHARGE);
            this.minNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MIN_NUM_RECORDS_PER_CLASS_BATTERYCHARGE);
            this.maxNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MAX_NUM_RECORDS_PER_CLASS_BATTERYCHARGE);
        } else if (predictionType == ClassifierConstants.PredictionType.BATTERYDRAIN) {
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_BATTERYDRAIN);
        } else if (predictionType == ClassifierConstants.PredictionType.BATTERYCHARGEDURATION) {
            this.class1Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_BATTERYCHARGEDURATION));
            this.class2Weight = asBigDecimal(propertyManager.getStringProperty(PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_BATTERYCHARGEDURATION));
            this.numTrees = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_NUM_TREES_BATTERYCHARGEDURATION);
            this.minNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MIN_NUM_RECORDS_PER_CLASS_BATTERYCHARGEDURATION);
            this.maxNumRecordsPerClass = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_MAX_NUM_RECORDS_PER_CLASS_BATTERYCHARGEDURATION);
        }
        this.randomSeed = propertyManager.getLongProperty(PropertyConstants.PREDICT_ML_RANDOM_SEED);
        // Re-seeded on every load so bagging is reproducible for a given seed.
        this.randomGenerator = new Random(this.randomSeed);
        this.bagPercentSize = propertyManager.getLongProperty(PropertyConstants.PREDICT_RF_BAG_PERCENT_SIZE);
        this.useRandomSplitting = propertyManager.getStringProperty(PropertyConstants.PREDICT_ML_RANDOM_ATTRIBUE_SPLITTER).equalsIgnoreCase("true");
        this.traceTreeIteration = propertyManager.getBooleanProperty(PropertyConstants.ML_OUTPUT_CLASSIFICATION_DEBUG);
    }

    /**
     * Copies labels, weights, seed, tree count, bag size and per-class record bounds from another
     * forest, and re-seeds this forest's random generator. The splitting and tracing flags are not
     * copied.
     */
    public void loadPropertiesFromAnotherForest(RandomForest randomForest) {
        this.class1Label = randomForest.class1Label;
        this.class2Label = randomForest.class2Label;
        this.class1Weight = randomForest.class1Weight;
        this.class2Weight = randomForest.class2Weight;
        this.randomSeed = randomForest.randomSeed;
        this.randomGenerator = new Random(this.randomSeed);
        this.numTrees = randomForest.numTrees;
        this.bagPercentSize = randomForest.bagPercentSize;
        this.minNumRecordsPerClass = randomForest.minNumRecordsPerClass;
        this.maxNumRecordsPerClass = randomForest.maxNumRecordsPerClass;
    }

    /** Concatenates each tree's stats, each prefixed with a single space. */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public String printStats() {
        StringBuilder stats = new StringBuilder();
        for (RandomForestTree tree : this.trees) {
            stats.append(' ').append(tree.printStats());
        }
        return stats.toString();
    }

    /**
     * Regression: pools every value returned by every tree's {@code regressValues} and returns their
     * mean, as a string label, with probability 1.0. The mean is NaN if no values were produced.
     */
    public ProbResults regress(FeatureSample featureSample) {
        List<Double> values = new ArrayList<Double>();
        for (RandomForestTree tree : this.trees) {
            values.addAll(tree.regressValues(featureSample));
        }
        return new ProbResults(Double.toString(calcMean(values)), 1.0d);
    }

    /**
     * Writes the forest as a zip archive with one entry per tree, named {@code RandomForestTree<i>}.
     *
     * <p>The archive is built fully in memory and then written to {@code outputStream} in one call.
     * {@code outputStream} itself is not closed.
     */
    @Override // com.xiam.consia.ml_new.classifiers.Classifier
    public void serialise(OutputStream outputStream) throws IOException {
        ByteArrayOutputStream zipBytes = new ByteArrayOutputStream();
        ByteArrayOutputStream treeBytes = new ByteArrayOutputStream();
        ZipOutputStream zip = new ZipOutputStream(new BufferedOutputStream(zipBytes));
        try {
            // Iterate the trees actually present; numTrees is only the configured target and may not
            // match a deserialised model's size.
            for (int i = 0; i < this.trees.size(); i++) {
                zip.putNextEntry(new ZipEntry("RandomForestTree" + i));
                treeBytes.reset();
                DataOutputStream treeOut = new DataOutputStream(treeBytes);
                this.trees.get(i).serialise(this.attributeNameSerialiser, treeOut);
                treeOut.flush();
                zip.write(treeBytes.toByteArray());
                zip.closeEntry();
            }
        } finally {
            zip.close();
        }
        // Only reached once the zip has been finished (closed), so the central directory is present.
        outputStream.write(zipBytes.toByteArray());
    }
}
